import torch
import numpy as np
import matplotlib.pyplot as plt
import util.landscape.util as util
import imageio
import os
import glob
from PIL import Image
from functools import partial


def _current_epoch(args):
    """Return the epoch index of the frame being rendered.

    The epoch used to be read both as ``args.current_epoch`` and as
    ``args.current.epoch``; either spelling is accepted so existing callers
    keep working. ``args.current_epoch`` wins when both are present.
    """
    epoch = getattr(args, "current_epoch", None)
    if epoch is None:
        epoch = args.current.epoch
    return epoch


def frame(x, y, Z, trajectories, args=None):
    """Render one landscape frame (heatmap + contours + trajectories) to a PNG.

    The image is written to
    ``../image/<task_name>/<id>_<xmin>_..._<vlevel>/frame_<epoch>.png`` and
    then shown with ``plt.show()``. The figure is always closed afterwards so
    that rendering many frames does not accumulate open figures.

    Args:
        x: Coordinates along the x axis, passed to ``np.meshgrid``.
        y: Coordinates along the y axis, passed to ``np.meshgrid``.
        Z: Values drawn as the heatmap and the contour lines.
        trajectories: Mapping of name -> array-like whose row 0 holds x
            coordinates and row 1 holds y coordinates, one column per step.
            May be empty or ``None``, in which case no path is drawn.
        args: Settings object. Attributes read here: ``xmin``, ``xmax``,
            ``ymin``, ``ymax``, ``xnum``, ``ynum``, ``vmin``, ``vmax``,
            ``vlevel``, ``fpp_scope``, ``auto_fit``, ``task_name``, ``id`` and
            the current epoch (``current_epoch`` or ``current.epoch``).

    Raises:
        ValueError: If ``args`` is not supplied.
    """
    if args is None:
        raise ValueError("frame() requires an 'args' object with the plot settings")

    epoch = _current_epoch(args)
    X, Y = np.meshgrid(x, y)
    fig = plt.figure(figsize=(10, 8))
    try:
        # Heatmap of the surface.
        heatmap = plt.imshow(
            Z,
            extent=(args.xmin, args.xmax, args.ymin, args.ymax),
            vmax=args.vmax,
            vmin=args.vmin,
            origin='lower',
            cmap='coolwarm',
            interpolation='bilinear',
            alpha=0.95,
        )

        # Contour lines on top of the heatmap, drawn thick and labelled.
        contour_plot = plt.contour(
            X, Y, Z,
            cmap='coolwarm',
            levels=np.arange(args.vmin, args.vmax, args.vlevel),
            alpha=1.0,
            linewidths=3,
        )
        plt.clabel(contour_plot, inline=True, fontsize=8)

        plt.colorbar(heatmap)

        # Tighten the margins around the axes.
        plt.subplots_adjust(left=0.03, right=1, top=0.95, bottom=0.05)

        # Draw each trajectory, restricted to the window of steps
        # [epoch, epoch + fpp_scope) and to points inside the plotted extent.
        if trajectories:
            window = slice(epoch, epoch + args.fpp_scope)
            for traj in trajectories.values():
                # Convert first so nested lists work, then take the step
                # window (the previous code passed the two bounds as separate
                # indices instead of a slice).
                points = np.asarray(traj)[:, window]
                inside = (
                    (points[0] >= args.xmin) & (points[0] <= args.xmax)
                    & (points[1] >= args.ymin) & (points[1] <= args.ymax)
                )
                visible = points[:, inside]
                plt.plot(visible[0], visible[1], color="#000000",
                         label=f'scope:{args.fpp_scope}', alpha=0.9)

        # Mark the origin; its legend entry shows the current epoch.
        plt.scatter(0, 0, color='#00ff00', s=20, label=f'{epoch}', zorder=16)

        plt.title("Animation with Trajectories")
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        plt.legend()

        # Optionally scale the axes by the ratio of the x and y ranges.
        if args.auto_fit:
            ax = plt.gca()
            ax.set_aspect(1.0 * ((args.xmax - args.xmin) / (args.ymax - args.ymin)))

        # Save the frame; exist_ok avoids the check-then-create race.
        dircty = os.path.join(f'../image/{args.task_name}/',
                              args.id + f"_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}")
        os.makedirs(dircty, exist_ok=True)
        plt.savefig(os.path.join(dircty, f"frame_{epoch}.png"))

        plt.show()
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

# def create_video_from_frames(args, fps=2):

#     image_files = sorted(glob.glob(os.path.join(os.path.join(f'../image/{args.task_name}/', 
#                           args.id + f"_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}"), "*.png")), key=util.extract_number)

#     if not image_files:
#         print("No PNG files found in the folder.")
#         return

#     with imageio.get_writer(os.path.join(f"../video/{args.task_name}", args.id + f"_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}.mp4" ), fps=fps) as writer:
#         for image_file in image_files:
#             img = Image.open(image_file)
#             writer.append_data(img)
#         print(f"Video generated!")


def create_video_from_frames(args, fps=2):
    """Assemble the saved PNG frames of one run into an MP4 video.

    Frames are read from ``../image/<task_name>/<run tag>/*.png`` and the
    video is written to ``../video/<task_name>/<run tag>.mp4``, where the run
    tag is ``<id>_<xmin>_<xmax>_<xnum>_<ymin>_<ymax>_<ynum>_<vmin>_<vmax>_<vlevel>``.
    If no PNG files are found, a message is printed and nothing is written.

    Args:
        args: Settings object providing ``task_name``, ``id``, ``xmin``,
            ``xmax``, ``xnum``, ``ymin``, ``ymax``, ``ynum``, ``vmin``,
            ``vmax`` and ``vlevel``.
        fps: Frames per second of the generated video.
    """
    # Build the run tag once; it names both the frame folder and the video.
    run_tag = args.id + f"_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}"
    frame_dir = os.path.join(f'../image/{args.task_name}/', run_tag)

    image_files = glob.glob(os.path.join(frame_dir, "*.png"))
    if not image_files:
        print("No PNG files found in the folder.")
        return

    # Order frames by the number util.extract_number pulls from each path
    # (presumably the trailing frame index, given order=-1 — verify in util).
    image_files.sort(key=partial(util.extract_number, order=-1))

    # Make sure the output folder exists; the writer does not create it.
    video_dir = f"../video/{args.task_name}"
    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, run_tag + ".mp4")

    with imageio.get_writer(video_path, fps=fps, codec='libx264') as writer:
        for image_file in image_files:
            # Convert to an ndarray while the file is open, then release the
            # file handle before moving on to the next frame.
            with Image.open(image_file) as img:
                writer.append_data(np.array(img))

    # Report success only after the writer has been closed.
    print("Video generated!")





